{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "712f9638",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "Named Entity Recognition (NER) is the process of identifying named entities in text. Example of named entities are: \"Person\", \"Location\", \"Organization\", \"Dates\" etc. NER is essentially(本质上) a token classification task where every token is classified into one or more predetermined(预定的) categories.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4db4588c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "--2022-12-21 20:49:39--  https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py\n",
      "Resolving raw.githubusercontent.com... 185.199.108.133, 185.199.111.133, 185.199.110.133, ...\n",
      "Connecting to raw.githubusercontent.com|185.199.108.133|:443... connected.\n",
      "Unable to establish SSL connection.\n"
     ]
    }
   ],
   "source": [
    "# !wget https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bc8b1b3b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from datasets import load_dataset\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9fa82894",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import division, print_function, unicode_literals\n",
    "\n",
    "import sys\n",
    "from collections import defaultdict\n",
    "\n",
    "def split_tag(chunk_tag):\n",
    "    if chunk_tag == 'O':\n",
    "        return ('O', None)\n",
    "    return chunk_tag.split('-', maxsplit=1)\n",
    "\n",
    "def is_chunk_end(prev_tag, tag):\n",
    "    prefix1, chunk_type1 = split_tag(prev_tag)\n",
    "    prefix2, chunk_type2 = split_tag(tag)\n",
    "\n",
    "    if prefix1 == 'O':\n",
    "        return False\n",
    "    if prefix2 == 'O':\n",
    "        return prefix1 != 'O'\n",
    "\n",
    "    if chunk_type1 != chunk_type2:\n",
    "        return True\n",
    "\n",
    "    return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S']\n",
    "\n",
    "def is_chunk_start(prev_tag, tag):\n",
    "    prefix1, chunk_type1 = split_tag(prev_tag)\n",
    "    prefix2, chunk_type2 = split_tag(tag)\n",
    "\n",
    "    if prefix2 == 'O':\n",
    "        return False\n",
    "    if prefix1 == 'O':\n",
    "        return prefix2 != 'O'\n",
    "\n",
    "    if chunk_type1 != chunk_type2:\n",
    "        return True\n",
    "\n",
    "    return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S']\n",
    "\n",
    "\n",
    "def calc_metrics(tp, p, t, percent=True):\n",
    "    precision = tp / p if p else 0\n",
    "    recall = tp / t if t else 0\n",
    "    fb1 = 2 * precision * recall / (precision + recall) if precision + recall else 0\n",
    "    if percent:\n",
    "        return 100 * precision, 100 * recall, 100 * fb1\n",
    "    else:\n",
    "        return precision, recall, fb1\n",
    "\n",
    "\n",
    "def count_chunks(true_seqs, pred_seqs):\n",
    "    \"\"\"\n",
    "    true_seqs: a list of true tags\n",
    "    pred_seqs: a list of predicted tags\n",
    "\n",
    "    return: \n",
    "    correct_chunks: a dict (counter), \n",
    "                    key = chunk types, \n",
    "                    value = number of correctly identified chunks per type\n",
    "    true_chunks:    a dict, number of true chunks per type\n",
    "    pred_chunks:    a dict, number of identified chunks per type\n",
    "\n",
    "    correct_counts, true_counts, pred_counts: similar to above, but for tags\n",
    "    \"\"\"\n",
    "    correct_chunks = defaultdict(int)\n",
    "    true_chunks = defaultdict(int)\n",
    "    pred_chunks = defaultdict(int)\n",
    "\n",
    "    correct_counts = defaultdict(int)\n",
    "    true_counts = defaultdict(int)\n",
    "    pred_counts = defaultdict(int)\n",
    "\n",
    "    prev_true_tag, prev_pred_tag = 'O', 'O'\n",
    "    correct_chunk = None\n",
    "\n",
    "    for true_tag, pred_tag in zip(true_seqs, pred_seqs):\n",
    "        if true_tag == pred_tag:\n",
    "            correct_counts[true_tag] += 1\n",
    "        true_counts[true_tag] += 1\n",
    "        pred_counts[pred_tag] += 1\n",
    "\n",
    "        _, true_type = split_tag(true_tag)\n",
    "        _, pred_type = split_tag(pred_tag)\n",
    "\n",
    "        if correct_chunk is not None:\n",
    "            true_end = is_chunk_end(prev_true_tag, true_tag)\n",
    "            pred_end = is_chunk_end(prev_pred_tag, pred_tag)\n",
    "\n",
    "            if pred_end and true_end:\n",
    "                correct_chunks[correct_chunk] += 1\n",
    "                correct_chunk = None\n",
    "            elif pred_end != true_end or true_type != pred_type:\n",
    "                correct_chunk = None\n",
    "\n",
    "        true_start = is_chunk_start(prev_true_tag, true_tag)\n",
    "        pred_start = is_chunk_start(prev_pred_tag, pred_tag)\n",
    "\n",
    "        if true_start and pred_start and true_type == pred_type:\n",
    "            correct_chunk = true_type\n",
    "        if true_start:\n",
    "            true_chunks[true_type] += 1\n",
    "        if pred_start:\n",
    "            pred_chunks[pred_type] += 1\n",
    "\n",
    "        prev_true_tag, prev_pred_tag = true_tag, pred_tag\n",
    "    if correct_chunk is not None:\n",
    "        correct_chunks[correct_chunk] += 1\n",
    "\n",
    "    return (correct_chunks, true_chunks, pred_chunks, \n",
    "        correct_counts, true_counts, pred_counts)\n",
    "\n",
    "def get_result(correct_chunks, true_chunks, pred_chunks,\n",
    "    correct_counts, true_counts, pred_counts, verbose=True):\n",
    "    \"\"\"\n",
    "    if verbose, print overall performance, as well as preformance per chunk type;\n",
    "    otherwise, simply return overall prec, rec, f1 scores\n",
    "    \"\"\"\n",
    "    # sum counts\n",
    "    sum_correct_chunks = sum(correct_chunks.values())\n",
    "    sum_true_chunks = sum(true_chunks.values())\n",
    "    sum_pred_chunks = sum(pred_chunks.values())\n",
    "\n",
    "    sum_correct_counts = sum(correct_counts.values())\n",
    "    sum_true_counts = sum(true_counts.values())\n",
    "\n",
    "    nonO_correct_counts = sum(v for k, v in correct_counts.items() if k != 'O')\n",
    "    nonO_true_counts = sum(v for k, v in true_counts.items() if k != 'O')\n",
    "\n",
    "    chunk_types = sorted(list(set(list(true_chunks) + list(pred_chunks))))\n",
    "\n",
    "    # compute overall precision, recall and FB1 (default values are 0.0)\n",
    "    prec, rec, f1 = calc_metrics(sum_correct_chunks, sum_pred_chunks, sum_true_chunks)\n",
    "    res = (prec, rec, f1)\n",
    "    if not verbose:\n",
    "        return res\n",
    "\n",
    "    # print overall performance, and performance per chunk type\n",
    "    \n",
    "    print(\"processed %i tokens with %i phrases; \" % (sum_true_counts, sum_true_chunks), end='')\n",
    "    print(\"found: %i phrases; correct: %i.\\n\" % (sum_pred_chunks, sum_correct_chunks), end='')\n",
    "        \n",
    "    print(\"accuracy: %6.2f%%; (non-O)\" % (100*nonO_correct_counts/nonO_true_counts))\n",
    "    print(\"accuracy: %6.2f%%; \" % (100*sum_correct_counts/sum_true_counts), end='')\n",
    "    print(\"precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f\" % (prec, rec, f1))\n",
    "\n",
    "    # for each chunk type, compute precision, recall and FB1 (default values are 0.0)\n",
    "    for t in chunk_types:\n",
    "        prec, rec, f1 = calc_metrics(correct_chunks[t], pred_chunks[t], true_chunks[t])\n",
    "        print(\"%17s: \" %t , end='')\n",
    "        print(\"precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f\" %\n",
    "                    (prec, rec, f1), end='')\n",
    "        print(\"  %d\" % pred_chunks[t])\n",
    "\n",
    "    return res\n",
    "    # you can generate LaTeX output for tables like in\n",
    "    # http://cnts.uia.ac.be/conll2003/ner/example.tex\n",
    "    # but I'm not implementing this\n",
    "\n",
    "def evaluate(true_seqs, pred_seqs, verbose=True):\n",
    "    (correct_chunks, true_chunks, pred_chunks,\n",
    "        correct_counts, true_counts, pred_counts) = count_chunks(true_seqs, pred_seqs)\n",
    "    result = get_result(correct_chunks, true_chunks, pred_chunks,\n",
    "        correct_counts, true_counts, pred_counts, verbose=verbose)\n",
    "    return result\n",
    "\n",
    "def evaluate_conll_file(fileIterator):\n",
    "    true_seqs, pred_seqs = [], []\n",
    "    \n",
    "    for line in fileIterator:\n",
    "        cols = line.strip().split()\n",
    "        # each non-empty line must contain >= 3 columns\n",
    "        if not cols:\n",
    "            true_seqs.append('O')\n",
    "            pred_seqs.append('O')\n",
    "        elif len(cols) < 3:\n",
    "            raise IOError(\"conlleval: too few columns in line %s\\n\" % line)\n",
    "        else:\n",
    "            # extract tags from last 2 columns\n",
    "            true_seqs.append(cols[-2])\n",
    "            pred_seqs.append(cols[-1])\n",
    "    return evaluate(true_seqs, pred_seqs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3208e983",
   "metadata": {},
   "source": [
    "定义一个transformer block层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bf5edba3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerBlock(layers.Layer):\n",
    "    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n",
    "        super(TransformerBlock, self).__init__()\n",
    "        self.att = keras.layers.MultiHeadAttention(\n",
    "            num_heads=num_heads, key_dim=embed_dim\n",
    "        )\n",
    "        self.ffn = keras.Sequential(\n",
    "            [\n",
    "                keras.layers.Dense(ff_dim, activation=\"relu\"),\n",
    "                keras.layers.Dense(embed_dim),\n",
    "            ]\n",
    "        )\n",
    "        self.layernorm1 = keras.layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.layernorm2 = keras.layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.dropout1 = keras.layers.Dropout(rate)\n",
    "        self.dropout2 = keras.layers.Dropout(rate)\n",
    "\n",
    "    def call(self, inputs, training=False):\n",
    "        attn_output = self.att(inputs, inputs)\n",
    "        attn_output = self.dropout1(attn_output, training=training)\n",
    "        out1 = self.layernorm1(inputs + attn_output)\n",
    "        ffn_output = self.ffn(out1)\n",
    "        ffn_output = self.dropout2(ffn_output, training=training)\n",
    "        return self.layernorm2(out1 + ffn_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc350ca7",
   "metadata": {},
   "source": [
    "定义TokenAndPositionEmbedding位置信息层:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "823f89a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TokenAndPositionEmbedding(layers.Layer):\n",
    "    def __init__(self, maxlen, vocab_size, embed_dim):\n",
    "        super(TokenAndPositionEmbedding, self).__init__()\n",
    "        self.token_emb = keras.layers.Embedding(\n",
    "            input_dim=vocab_size, output_dim=embed_dim\n",
    "        )\n",
    "        self.pos_emb = keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        maxlen = tf.shape(inputs)[-1]\n",
    "        positions = tf.range(start=0, limit=maxlen, delta=1)\n",
    "        position_embeddings = self.pos_emb(positions)\n",
    "        token_embeddings = self.token_emb(inputs)\n",
    "        return token_embeddings + position_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d6ed5911",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NERModel(keras.Model):\n",
    "    def __init__(\n",
    "        self, num_tags, vocab_size, maxlen=128, embed_dim=32, num_heads=2, ff_dim=32\n",
    "    ):\n",
    "        super(NERModel, self).__init__()\n",
    "        self.embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)\n",
    "        self.transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)\n",
    "        self.dropout1 = layers.Dropout(0.1)\n",
    "        self.ff = layers.Dense(ff_dim, activation=\"relu\")\n",
    "        self.dropout2 = layers.Dropout(0.1)\n",
    "        self.ff_final = layers.Dense(num_tags, activation=\"softmax\")\n",
    "\n",
    "    def call(self, inputs, training=False):\n",
    "        x = self.embedding_layer(inputs)\n",
    "        x = self.transformer_block(x)\n",
    "        x = self.dropout1(x, training=training)\n",
    "        x = self.ff(x)\n",
    "        x = self.dropout2(x, training=training)\n",
    "        x = self.ff_final(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94f02f32",
   "metadata": {},
   "source": [
    "加载CoNLL 2003 dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fa002e09",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset conll2003 (C:/Users/wulin/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fbc8a5deee1e46f5bc487db773f8a528",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "conll_data = load_dataset(\"conll2003\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56e02800",
   "metadata": {},
   "source": [
    "将此数据导出为制表符分隔的文件格式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "757096a1",
   "metadata": {},
   "outputs": [
    {
     "ename": "FileExistsError",
     "evalue": "[WinError 183] 当文件已存在时，无法创建该文件。: 'data'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mFileExistsError\u001b[0m                           Traceback (most recent call last)",
      "Cell \u001b[1;32mIn [7], line 17\u001b[0m\n\u001b[0;32m      6\u001b[0m             \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(tokens) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m      7\u001b[0m                 f\u001b[38;5;241m.\u001b[39mwrite(\n\u001b[0;32m      8\u001b[0m                     \u001b[38;5;28mstr\u001b[39m(\u001b[38;5;28mlen\u001b[39m(tokens))\n\u001b[0;32m      9\u001b[0m                     \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m     13\u001b[0m                     \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m     14\u001b[0m                 )\n\u001b[1;32m---> 17\u001b[0m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmkdir\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdata\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m     18\u001b[0m export_to_file(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./data/conll_train.txt\u001b[39m\u001b[38;5;124m\"\u001b[39m, conll_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[0;32m     19\u001b[0m export_to_file(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./data/conll_val.txt\u001b[39m\u001b[38;5;124m\"\u001b[39m, conll_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n",
      "\u001b[1;31mFileExistsError\u001b[0m: [WinError 183] 当文件已存在时，无法创建该文件。: 'data'"
     ]
    }
   ],
   "source": [
    "def export_to_file(export_file_path, data):\n",
    "    with open(export_file_path, \"w\") as f:\n",
    "        for record in data:\n",
    "            ner_tags = record[\"ner_tags\"]\n",
    "            tokens = record[\"tokens\"]\n",
    "            if len(tokens) > 0:\n",
    "                f.write(\n",
    "                    str(len(tokens))\n",
    "                    + \"\\t\"\n",
    "                    + \"\\t\".join(tokens)\n",
    "                    + \"\\t\"\n",
    "                    + \"\\t\".join(map(str, ner_tags))\n",
    "                    + \"\\n\"\n",
    "                )\n",
    "\n",
    "\n",
    "os.mkdir(\"data\")\n",
    "export_to_file(\"./data/conll_train.txt\", conll_data[\"train\"])\n",
    "export_to_file(\"./data/conll_val.txt\", conll_data[\"validation\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88e57d77",
   "metadata": {},
   "source": [
    "## 制作 NER 标签查找表\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e8359a3a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: '[PAD]', 1: 'O', 2: 'B-PER', 3: 'I-PER', 4: 'B-ORG', 5: 'I-ORG', 6: 'B-LOC', 7: 'I-LOC', 8: 'B-MISC', 9: 'I-MISC'}\n"
     ]
    }
   ],
   "source": [
    "def make_tag_lookup_table():\n",
    "    iob_labels = [\"B\", \"I\"]\n",
    "    ner_labels = [\"PER\", \"ORG\", \"LOC\", \"MISC\"]\n",
    "    all_labels = [(label1, label2) for label2 in ner_labels for label1 in iob_labels]\n",
    "    all_labels = [\"-\".join([a, b]) for a, b in all_labels]\n",
    "    all_labels = [\"[PAD]\", \"O\"] + all_labels\n",
    "    return dict(zip(range(0, len(all_labels) + 1), all_labels))\n",
    "\n",
    "\n",
    "mapping = make_tag_lookup_table()\n",
    "print(mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c10f0597",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "21009\n"
     ]
    }
   ],
   "source": [
    "#获取训练数据集中所有标记的列表。用于创建词汇表。\n",
    "all_tokens = sum(conll_data[\"train\"][\"tokens\"], [])\n",
    "all_tokens_array = np.array(list(map(str.lower, all_tokens)))\n",
    "\n",
    "counter = Counter(all_tokens_array)\n",
    "print(len(counter))\n",
    "\n",
    "num_tags = len(mapping)\n",
    "vocab_size = 20000\n",
    "\n",
    "vocabulary = [token for token, count in counter.most_common(vocab_size - 2)]\n",
    "\n",
    "lookup_layer = keras.layers.StringLookup(\n",
    "    vocabulary=vocabulary\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "73750ce3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#从训练和验证数据创建 2 个新对象\n",
    "train_data = tf.data.TextLineDataset(\"./data/conll_train.txt\")\n",
    "val_data = tf.data.TextLineDataset(\"./data/conll_val.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ba9b03e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#转换数据集中的数据\n",
    "def map_record_to_training_data(record):\n",
    "    record = tf.strings.split(record, sep=\"\\t\")\n",
    "    length = tf.strings.to_number(record[0], out_type=tf.int32)\n",
    "    tokens = record[1 : length + 1]\n",
    "    tags = record[length + 1 :]\n",
    "    tags = tf.strings.to_number(tags, out_type=tf.int64)\n",
    "    tags += 1\n",
    "    return tokens, tags\n",
    "\n",
    "def lowercase_and_convert_to_ids(tokens):\n",
    "    tokens = tf.strings.lower(tokens)\n",
    "    return lookup_layer(tokens)\n",
    "\n",
    "batch_size = 32\n",
    "train_dataset = (\n",
    "    train_data.map(map_record_to_training_data)\n",
    "    .map(lambda x, y: (lowercase_and_convert_to_ids(x), y))\n",
    "    .padded_batch(batch_size)\n",
    ")\n",
    "val_dataset = (\n",
    "    val_data.map(map_record_to_training_data)\n",
    "    .map(lambda x, y: (lowercase_and_convert_to_ids(x), y))\n",
    "    .padded_batch(batch_size)\n",
    ")\n",
    "\n",
    "ner_model = NERModel(num_tags, vocab_size, embed_dim=32, num_heads=4, ff_dim=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5a26c7d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomNonPaddingTokenLoss(keras.losses.Loss):\n",
    "    def __init__(self, name=\"custom_ner_loss\"):\n",
    "        super().__init__(name=name)\n",
    "\n",
    "    def call(self, y_true, y_pred):\n",
    "        loss_fn = keras.losses.SparseCategoricalCrossentropy(\n",
    "            from_logits=True, reduction=keras.losses.Reduction.NONE\n",
    "        )\n",
    "        loss = loss_fn(y_true, y_pred)\n",
    "        mask = tf.cast((y_true > 0), dtype=tf.float32)\n",
    "        loss = loss * mask\n",
    "        return tf.reduce_sum(loss) / tf.reduce_sum(mask)\n",
    "\n",
    "loss = CustomNonPaddingTokenLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b81002d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Miniconda3\\envs\\NLP\\lib\\site-packages\\keras\\backend.py:5582: UserWarning: \"`sparse_categorical_crossentropy` received `from_logits=True`, but the `output` argument was produced by a Softmax activation and thus does not represent logits. Was this intended?\n",
      "  output, from_logits = _get_logits(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "439/439 [==============================] - 10s 17ms/step - loss: 0.6508\n",
      "Epoch 2/10\n",
      "439/439 [==============================] - 6s 14ms/step - loss: 0.2504\n",
      "Epoch 3/10\n",
      "439/439 [==============================] - 6s 14ms/step - loss: 0.1520\n",
      "Epoch 4/10\n",
      "439/439 [==============================] - 6s 14ms/step - loss: 0.1172\n",
      "Epoch 5/10\n",
      "439/439 [==============================] - 6s 14ms/step - loss: 0.0927\n",
      "Epoch 6/10\n",
      "439/439 [==============================] - 6s 14ms/step - loss: 0.0755\n",
      "Epoch 7/10\n",
      "439/439 [==============================] - 6s 14ms/step - loss: 0.0662\n",
      "Epoch 8/10\n",
      "439/439 [==============================] - 6s 14ms/step - loss: 0.0581\n",
      "Epoch 9/10\n",
      "439/439 [==============================] - 6s 14ms/step - loss: 0.0489\n",
      "Epoch 10/10\n",
      "439/439 [==============================] - 6s 14ms/step - loss: 0.0432\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1c52f37e950>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#编译和训练\n",
    "ner_model.compile(optimizer=\"adam\", loss=loss)\n",
    "ner_model.fit(train_dataset, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9a2e568b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_and_convert_to_ids(text):\n",
    "    tokens = text.split()\n",
    "    return lowercase_and_convert_to_ids(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "aa046f49",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor([[  988 10950   204   628     6  3938   215  5773]], shape=(1, 8), dtype=int64)\n",
      "1/1 [==============================] - 0s 26ms/step\n",
      "['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O']\n"
     ]
    }
   ],
   "source": [
    "# 简单测试\n",
    "# for it in input1:\n",
    "sample_input = tokenize_and_convert_to_ids(\n",
    "\"eu rejects german call to boycott british lamb\")\n",
    "sample_input = tf.reshape(sample_input, shape=[1, -1])\n",
    "print(sample_input)\n",
    "output = ner_model.predict(sample_input)\n",
    "prediction = np.argmax(output, axis=-1)[0]\n",
    "# print(it)\n",
    "prediction = [mapping[i] for i in prediction]\n",
    "print(prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a74050de",
   "metadata": {},
   "source": [
    "## 指标计算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5eecf47a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 205ms/step\n",
      "1/1 [==============================] - 0s 28ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 27ms/step\n",
      "1/1 [==============================] - 0s 31ms/step\n",
      "1/1 [==============================] - 0s 31ms/step\n",
      "1/1 [==============================] - 0s 25ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 31ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 28ms/step\n",
      "1/1 [==============================] - 0s 26ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 31ms/step\n",
      "1/1 [==============================] - 0s 38ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 29ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 29ms/step\n",
      "1/1 [==============================] - 0s 35ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 35ms/step\n",
      "1/1 [==============================] - 0s 31ms/step\n",
      "1/1 [==============================] - 0s 31ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 35ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 38ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 34ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 35ms/step\n",
      "1/1 [==============================] - 0s 34ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 31ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 31ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 31ms/step\n",
      "1/1 [==============================] - 0s 34ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 35ms/step\n",
      "1/1 [==============================] - 0s 36ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 58ms/step\n",
      "1/1 [==============================] - 0s 34ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 29ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 29ms/step\n",
      "1/1 [==============================] - 0s 31ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 29ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 27ms/step\n",
      "1/1 [==============================] - 0s 29ms/step\n",
      "1/1 [==============================] - 0s 26ms/step\n",
      "1/1 [==============================] - 0s 25ms/step\n",
      "1/1 [==============================] - 0s 27ms/step\n",
      "1/1 [==============================] - 0s 35ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 48ms/step\n",
      "1/1 [==============================] - 0s 27ms/step\n",
      "1/1 [==============================] - 0s 29ms/step\n",
      "1/1 [==============================] - 0s 28ms/step\n",
      "1/1 [==============================] - 0s 35ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 34ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 25ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 36ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 48ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 35ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 31ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 40ms/step\n",
      "1/1 [==============================] - 0s 34ms/step\n",
      "1/1 [==============================] - 0s 35ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 34ms/step\n",
      "1/1 [==============================] - 0s 35ms/step\n",
      "1/1 [==============================] - 0s 35ms/step\n",
      "1/1 [==============================] - 0s 32ms/step\n",
      "1/1 [==============================] - 0s 29ms/step\n",
      "1/1 [==============================] - 0s 33ms/step\n",
      "1/1 [==============================] - 0s 30ms/step\n",
      "processed 51362 tokens with 5942 phrases; found: 5312 phrases; correct: 3883.\n",
      "accuracy:  63.52%; (non-O)\n",
      "accuracy:  93.55%; precision:  73.10%; recall:  65.35%; FB1:  69.01\n",
      "              LOC: precision:  86.83%; recall:  78.99%; FB1:  82.73  1671\n",
      "             MISC: precision:  72.25%; recall:  66.92%; FB1:  69.48  854\n",
      "              ORG: precision:  64.81%; recall:  62.49%; FB1:  63.63  1293\n",
      "              PER: precision:  65.39%; recall:  53.04%; FB1:  58.57  1494\n"
     ]
    }
   ],
   "source": [
    "def calculate_metrics(dataset):\n",
    "    all_true_tag_ids, all_predicted_tag_ids = [], []\n",
    "\n",
    "    for x, y in dataset:\n",
    "        output = ner_model.predict(x)\n",
    "        predictions = np.argmax(output, axis=-1)\n",
    "        predictions = np.reshape(predictions, [-1])\n",
    "\n",
    "        true_tag_ids = np.reshape(y, [-1])\n",
    "\n",
    "        mask = (true_tag_ids > 0) & (predictions > 0)\n",
    "        true_tag_ids = true_tag_ids[mask]\n",
    "        predicted_tag_ids = predictions[mask]\n",
    "\n",
    "        all_true_tag_ids.append(true_tag_ids)\n",
    "        all_predicted_tag_ids.append(predicted_tag_ids)\n",
    "\n",
    "    all_true_tag_ids = np.concatenate(all_true_tag_ids)\n",
    "    all_predicted_tag_ids = np.concatenate(all_predicted_tag_ids)\n",
    "\n",
    "    predicted_tags = [mapping[tag] for tag in all_predicted_tag_ids]\n",
    "    real_tags = [mapping[tag] for tag in all_true_tag_ids]\n",
    "\n",
    "    evaluate(real_tags, predicted_tags)\n",
    "\n",
    "\n",
    "calculate_metrics(val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "291b934f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
